import argparse
import math

import gymnasium as gym
import numpy as np

from sac.replay_memory import ReplayMemory
from sac.sac import SAC
from sac.mpo import MPO
from model import EnsembleDynamicsModel
from model import EnsembleVariationalDynamicsModel
from predict_env import PredictEnv
from predict_env import VarPredictEnv
from sample_env import EnvSampler
import matplotlib.pyplot as plt
import pickle
import torch

def _str2bool(value):
    """Convert a command-line token such as 'true' or '0' into a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including "False") becomes ``True``. This parser accepts
    the usual spellings and rejects everything else.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {0!r}'.format(value))


def readParser():
    """Build the command-line parser and return the parsed arguments.

    Returns:
        argparse.Namespace with one attribute per option defined below.
        Boolean options (``--use_decay``, ``--automatic_entropy_tuning``,
        ``--automatic_beta_tuning``) accept true/false style values.
        ``--cuda`` defaults to True; pass ``--no_cuda`` to disable it.
    """
    parser = argparse.ArgumentParser(description='MBPO')
    parser.add_argument('--env_name', default="Walker2d-v4",
                        help='Mujoco Gym environment')
    parser.add_argument('--seed', type=int, default=123456, metavar='N',
                        help='random seed (default: 123456)')
    parser.add_argument('--agent', default="SAC",
                        help='Actor-Critic Optimization')
    parser.add_argument('--model', default="VMBPO",
                        help='Model Optimization')
    parser.add_argument('--use_decay', type=_str2bool, default=True, metavar='G',
                        help='Weight decay for network.')
    parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                        help='discount factor for reward (default: 0.99)')
    parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                        help='target smoothing coefficient(τ) (default: 0.005)')
    parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                        help='Temperature parameter for entropy scaling (default: 0.2)')
    parser.add_argument('--beta', type=float, default=100.0, metavar='G',
                        help='Risk parameter initialization for KL scaling (default: 100.0)')
    parser.add_argument('--policy', default="Gaussian",
                        help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
    parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                        help='Value target update per no. of updates per step (default: 1)')
    parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=False, metavar='G',
                        help='Automatically adjust alpha')
    parser.add_argument('--automatic_beta_tuning', type=_str2bool, default=True, metavar='G',
                        help='Automatically adjust beta')
    parser.add_argument('--hidden_size', type=int, default=256, metavar='N',
                        help='hidden size (default: 256)')
    parser.add_argument('--lr', type=float, default=0.0003, metavar='G',
                        help='learning rate (default: 0.0003)')
    parser.add_argument('--num_networks', type=int, default=7, metavar='E',
                        help='ensemble size (default: 7)')
    parser.add_argument('--num_elites', type=int, default=5, metavar='E',
                        help='elite size (default: 5)')
    parser.add_argument('--pred_hidden_size', type=int, default=200, metavar='E',
                        help='hidden size for predictive model')
    parser.add_argument('--reward_size', type=int, default=1, metavar='E',
                        help='environment reward size')
    parser.add_argument('--experience_replay_size', type=int, default=1000000, metavar='N',
                        help='size of experience buffer')
    parser.add_argument('--model_replay_size', type=int, default=400000, metavar='N',
                        help='size of model buffer')
    parser.add_argument('--model_retain_epochs', type=int, default=1, metavar='A',
                        help='retain epochs')
    parser.add_argument('--model_train_freq', type=int, default=250, metavar='A',
                        help='frequency of training')
    parser.add_argument('--rollout_batch_size', type=int, default=100000, metavar='A',
                        help='rollout number M')
    parser.add_argument('--epoch_length', type=int, default=1000, metavar='A',
                        help='steps per epoch')
    parser.add_argument('--num_epoch', type=int, default=300, metavar='A',
                        help='total number of epochs')
    parser.add_argument('--min_pool_size', type=int, default=1000, metavar='A',
                        help='minimum pool size')
    parser.add_argument('--real_ratio', type=float, default=0.05, metavar='A',
                        help='ratio of env samples / model samples')
    parser.add_argument('--num_train_repeat', type=int, default=20, metavar='A',
                        help='times to training policy per step')
    parser.add_argument('--policy_train_batch_size', type=int, default=256, metavar='A',
                        help='batch size for training policy')
    parser.add_argument('--init_exploration_steps', type=int, default=5000, metavar='A',
                        help='exploration steps initially')
    parser.add_argument('--max_path_length', type=int, default=1000, metavar='A',
                        help='max length of path')
    parser.add_argument('--cuda', default=True, action="store_true",
                        help='run on CUDA (default: True)')
    # '--cuda' is store_true with default True, so on its own it can never be
    # turned off; '--no_cuda' writes False into the same destination.
    parser.add_argument('--no_cuda', dest='cuda', action="store_false",
                        help='do not run on CUDA')
    return parser.parse_args()

def train(args, env_sampler, predict_env, agent, env_pool, model_pool):
    """Run the full training loop: data collection, model fitting, policy updates, evaluation.

    For every epoch this performs ``args.epoch_length`` environment steps. Each
    step pushes one real transition into ``env_pool``, trains the agent via
    ``train_policy_repeats`` and, when enabled, takes one gradient step on the
    KL weight ``beta``. Every ``args.model_train_freq`` steps the dynamics
    model(s) are refit and ``model_pool`` is refilled by ``rollout_model``.
    After each epoch the agent is evaluated for 5 episodes and the average
    return is printed and recorded.

    Side effects: every 10th epoch the list of evaluation returns is pickled to
    disk; every 25th epoch the state dicts of the prior model, the variational
    model, the critic and the policy are saved with ``torch.save``.

    Args:
        args: parsed options (see ``readParser``).
        env_sampler: object whose ``sample(agent, eval_t=...)`` returns
            ``(cur_state, action, next_state, reward, done)``.
        predict_env: learned-model environment exposing ``model``,
            ``var_model``, ``step`` and ``_get_logprob``.
        agent: policy/critic object (``critic``, ``policy``, ``q_c``, ...).
        env_pool: replay buffer of real transitions.
        model_pool: replay buffer of model-generated transitions.
    """
    beta = np.array([args.beta])
    # Only log_beta lives on this device (beta is handed around as a NumPy
    # array), so fall back to the CPU instead of crashing when CUDA is
    # disabled or unavailable.
    use_cuda = bool(getattr(args, 'cuda', True)) and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    log_beta = torch.tensor([math.log(args.beta)], device=device, requires_grad=True)
    print(log_beta)
    beta_optim = torch.optim.Adam([log_beta], lr=args.lr)
    experiment_num = 5

    exploration_before_start(args, env_sampler, env_pool, agent)
    graph_G = []
    # Latest KL estimate. It stays None until the beta-tuning branch has run at
    # least once, so the per-epoch print below cannot hit an unbound variable.
    KL = None
    for epoch_step in range(args.num_epoch):
        if epoch_step % 10 == 9:
            # Periodically persist the evaluation returns collected so far.
            with open("{0}_return_{1}_{2}.pickle".format(1.0, experiment_num, args.env_name), "wb") as fp:
                pickle.dump(graph_G, fp)
        if epoch_step % 25 == 24:
            # Periodic checkpoint of models and agent.
            # NOTE(review): predict_env.var_model is accessed regardless of
            # args.model — confirm it exists when running with --model MBPO.
            torch.save(predict_env.model.ensemble_model.state_dict(),
                       "{0}_{1}_{2}.pt".format(args.env_name, 'prior_dynamics', experiment_num))
            torch.save(predict_env.var_model.ensemble_model.state_dict(),
                       "{0}_{1}_{2}.pt".format(args.env_name, 'posterior_dynamics', experiment_num))
            torch.save(agent.critic.state_dict(),
                       "{0}_{1}_{2}.pt".format(args.env_name, 'critic', experiment_num))
            torch.save(agent.policy.state_dict(),
                       "{0}_{1}_{2}.pt".format(args.env_name, 'post_policy', experiment_num))

        for i in range(args.epoch_length):
            if i > 0 and i % args.model_train_freq == 0 and args.real_ratio < 1.0:
                if args.model == 'MBPO':
                    train_predict_model(env_pool, predict_env.model)
                elif args.model == 'VMBPO':
                    train_predict_model(env_pool, predict_env.model)
                    train_variational_predict_model(env_pool, predict_env.var_model, agent.critic, agent.q_c, beta)
                rollout_model(args, predict_env, agent, model_pool, env_pool)

            # Collect one real transition.
            cur_state, action, next_state, reward, done = env_sampler.sample(agent)
            env_pool.push(cur_state, action, reward, next_state, done)

            # Train agent
            train_policy_repeats(args, env_pool, model_pool, agent)

            # Beta optimization (skipped once beta has reached the 1e7 cap).
            # NOTE(review): this branch also reads predict_env.var_model — see
            # the checkpoint note above.
            if len(model_pool) > 0 and args.automatic_beta_tuning and beta[0] < 10000000:
                env_state, env_action, _, _, _ = model_pool.sample(256)

                inputs = np.concatenate((env_state, env_action), axis=-1)

                # Variational model output; the state is added back, so it
                # presumably predicts state deltas (matches the labels built in
                # train_variational_predict_model).
                q_mu, q_var = predict_env.var_model.predict(inputs)
                q_mu += env_state

                # NOTE(review): the noise is scaled by q_var directly. If
                # predict() returns variances rather than standard deviations
                # this should be np.sqrt(q_var) — confirm against the model.
                ensemble_samples = q_mu + np.random.normal(size=q_mu.shape) * q_var
                model_idxes = np.random.choice(predict_env.var_model.elite_model_idxes)
                env_next_state = ensemble_samples[model_idxes]

                q_mu = q_mu[model_idxes]
                q_var = q_var[model_idxes]

                log_q, _ = predict_env._get_logprob(env_next_state, q_mu, q_var)

                # The first output column is dropped; presumably it is the
                # reward prediction, since train_predict_model puts the reward
                # first in its labels.
                p_mu, p_var = predict_env.model.predict(inputs)
                p_mu[:, :, 1:] += env_state
                p_mu = p_mu[:, :, 1:]
                p_var = p_var[:, :, 1:]
                model_idxes = np.random.choice(predict_env.model.elite_model_idxes)
                p_mu = p_mu[model_idxes]
                p_var = p_var[model_idxes]

                log_p, _ = predict_env._get_logprob(env_next_state, p_mu, p_var)

                # Sample-based estimate of KL(q || p) on the drawn next states.
                KL = np.mean(log_q - log_p)

                # d(loss)/d(log_beta) = 10 - KL, so the step lowers log_beta
                # while KL < 10 and raises it while KL > 10.
                beta_loss = (log_beta * (10 - KL)).mean()

                beta_optim.zero_grad()
                beta_loss.backward()
                beta_optim.step()

                beta = np.array([log_beta.exp().item()])

        print(KL)
        print(beta[0])

        # Evaluation: 5 episodes, each ending on `done` or after
        # args.max_path_length steps.
        env_sampler.current_state = None
        sum_reward = 0
        for test_iter in range(5):
            done = False
            test_step = 0
            while (not done) and (test_step != args.max_path_length):
                cur_state, action, next_state, reward, done = env_sampler.sample(agent, eval_t=True)
                sum_reward += reward
                test_step += 1
        graph_G.append(sum_reward/5)
        print(epoch_step, sum_reward/5)

def exploration_before_start(args, env_sampler, env_pool, agent):
    """Pre-fill ``env_pool`` with ``args.init_exploration_steps`` sampled transitions.

    Each transition comes from ``env_sampler.sample(agent)`` and is stored in
    the pool's ``(state, action, reward, next_state, done)`` argument order.
    """
    warmup_steps = args.init_exploration_steps
    for _ in range(warmup_steps):
        # The sampler yields (state, action, next_state, reward, done); the
        # pool expects reward before next_state.
        obs, act, next_obs, rew, terminal = env_sampler.sample(agent)
        env_pool.push(obs, act, rew, next_obs, terminal)

def train_predict_model(env_pool, model):
    """Fit the dynamics model on every transition currently in ``env_pool``.

    Inputs are ``[state, action]`` and labels are ``[reward, next_state - state]``,
    both concatenated along the last axis. The model is trained with a batch
    size of 1024 and a holdout ratio of 0.2.
    """
    # Draw as many samples as the pool holds.
    pool_size = len(env_pool)
    state, action, reward, next_state, done = env_pool.sample(pool_size)

    model_inputs = np.concatenate((state, action), axis=-1)

    # Rewards become a column block so they can sit in front of the state deltas.
    reward_column = np.reshape(reward, (reward.shape[0], -1))
    state_change = next_state - state
    model_targets = np.concatenate((reward_column, state_change), axis=-1)

    model.train(model_inputs, model_targets, batch_size=1024, holdout_ratio=0.2)

def train_variational_predict_model(env_pool, model, Q, q_c, beta):
    """Fit the variational dynamics model on every transition in ``env_pool``.

    Inputs are ``[state, action]`` concatenated along the last axis and labels
    are the state deltas ``next_state - state``. The raw transition arrays,
    ``Q``, ``q_c`` and ``beta`` are forwarded unchanged to ``model.train``,
    which runs with a batch size of 256 and a holdout ratio of 0.2.
    """
    # Draw as many samples as the pool holds.
    transitions = env_pool.sample(len(env_pool))
    state, action, reward, next_state, done = transitions

    state_change = next_state - state
    model_inputs = np.concatenate((state, action), axis=-1)

    model.train(
        model_inputs, state_change,
        state, action, reward, next_state, done,
        Q, q_c,
        batch_size=256, holdout_ratio=0.2, beta=beta,
    )

def rollout_model(args, predict_env, agent, model_pool, env_pool):
    """Generate one-step model rollouts and store them in ``model_pool``.

    Start states are drawn from ``env_pool`` (``args.rollout_batch_size`` of
    them), the agent picks fresh actions for those states, and
    ``predict_env.step`` supplies the predicted next states and rewards.
    """
    start_states, _, _, _, sampled_done = env_pool.sample_all_batch(args.rollout_batch_size)
    chosen_actions = agent.select_action(start_states)

    predicted_next, predicted_rewards, _ = predict_env.step(start_states, chosen_actions)
    # NOTE(review): the terminals returned by predict_env.step are discarded
    # and the `done` flags of the sampled real transitions are stored instead —
    # confirm this is intentional.
    terminal_flags = sampled_done[:, None]

    rollout_count = start_states.shape[0]
    transitions = [
        (start_states[k], chosen_actions[k], predicted_rewards[k],
         predicted_next[k], terminal_flags[k])
        for k in range(rollout_count)
    ]
    model_pool.push_batch(transitions)


def train_policy_repeats(args, env_pool, model_pool, agent):
    """Run ``args.num_train_repeat`` agent updates on mixed real/model batches.

    Each batch holds ``int(policy_train_batch_size * real_ratio)`` real
    transitions from ``env_pool``; the remainder comes from ``model_pool`` when
    it is non-empty. Rewards and done flags are squeezed, and the done flags
    are inverted into an integer not-done mask before being passed to
    ``agent.update_parameters``.
    """
    total_batch = args.policy_train_batch_size
    real_count = int(total_batch * args.real_ratio)
    synthetic_count = total_batch - real_count

    for repeat_idx in range(args.num_train_repeat):
        env_state, env_action, env_reward, env_next_state, env_done = env_pool.sample(int(real_count))

        use_model_data = synthetic_count > 0 and len(model_pool) > 0
        if not use_model_data:
            batch_state = env_state
            batch_action = env_action
            batch_reward = env_reward
            batch_next_state = env_next_state
            batch_done = env_done
        else:
            (model_state, model_action, model_reward,
             model_next_state, model_done) = model_pool.sample_all_batch(int(synthetic_count))

            # Real rewards / dones are reshaped into columns so they line up
            # with the model pool's entries before stacking.
            env_reward_col = np.reshape(env_reward, (env_reward.shape[0], -1))
            env_done_col = np.reshape(env_done, (env_done.shape[0], -1))

            batch_state = np.concatenate((env_state, model_state), axis=0)
            batch_action = np.concatenate((env_action, model_action), axis=0)
            batch_reward = np.concatenate((env_reward_col, model_reward), axis=0)
            batch_next_state = np.concatenate((env_next_state, model_next_state), axis=0)
            batch_done = np.concatenate((env_done_col, model_done), axis=0)

        batch_reward = np.squeeze(batch_reward)
        # Invert done flags into an integer mask (1 = episode continues).
        batch_done = (~np.squeeze(batch_done)).astype(int)

        batch = (batch_state, batch_action, batch_reward, batch_next_state, batch_done)
        agent.update_parameters(batch, total_batch, repeat_idx)

def main(args=None):
    """Build the environment, agent, dynamics models and buffers, then train.

    Args:
        args: optional pre-built namespace of options; when ``None`` the
            command line is parsed with ``readParser``.
    """
    if args is None:
        args = readParser()

    # Apply the seed before any network is constructed; previously --seed was
    # parsed but never used, so runs were not reproducible.
    seed = getattr(args, 'seed', None)
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    # Initial environment
    env = gym.make(args.env_name)
    if seed is not None:
        env.action_space.seed(seed)

    # Actor-critic: anything other than 'SAC' selects MPO.
    if args.agent == 'SAC':
        agent = SAC(env.observation_space.shape[0], env.action_space, args)
    else:
        agent = MPO(env.observation_space.shape[0], env.action_space, args)

    # Dynamics model
    state_size = np.prod(env.observation_space.shape)
    action_size = np.prod(env.action_space.shape)

    # Both 'MBPO' and 'VMBPO' use the same prior ensemble; any other value
    # leaves the prior model unset.
    if args.model in ('MBPO', 'VMBPO'):
        env_model = EnsembleDynamicsModel(args.num_networks, args.num_elites, state_size, action_size,
                                          args.reward_size, args.pred_hidden_size,
                                          use_decay=args.use_decay)
    else:
        env_model = None

    if args.model == 'VMBPO':
        # The variational ensemble takes no reward_size argument.
        var_model = EnsembleVariationalDynamicsModel(args.num_networks, args.num_elites, state_size, action_size,
                                                     args.pred_hidden_size,
                                                     use_decay=args.use_decay)
        predict_env = VarPredictEnv(env_model, var_model, args.env_name)
    else:
        predict_env = PredictEnv(env_model, args.env_name)

    # Initial pool for env
    env_pool = ReplayMemory(args.experience_replay_size)
    # Initial pool for model
    model_pool = ReplayMemory(args.model_replay_size)

    # Sampler of environment
    env_sampler = EnvSampler(env, max_path_length=args.max_path_length)

    train(args, env_sampler, predict_env, agent, env_pool, model_pool)


# Script entry point: run main() (which parses command-line options itself)
# only when this file is executed directly.
if __name__ == '__main__':
    main()
